adaptive-harmony 0.1.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adaptive_harmony/__init__.py +162 -0
- adaptive_harmony/common/__init__.py +40 -0
- adaptive_harmony/common/callbacks.py +219 -0
- adaptive_harmony/common/checkpointing.py +163 -0
- adaptive_harmony/common/dpo.py +92 -0
- adaptive_harmony/common/env_grpo.py +361 -0
- adaptive_harmony/common/grpo.py +260 -0
- adaptive_harmony/common/gspo.py +70 -0
- adaptive_harmony/common/ppo.py +303 -0
- adaptive_harmony/common/rm.py +79 -0
- adaptive_harmony/common/sft.py +121 -0
- adaptive_harmony/core/__init__.py +0 -0
- adaptive_harmony/core/dataset.py +72 -0
- adaptive_harmony/core/display.py +93 -0
- adaptive_harmony/core/image_utils.py +110 -0
- adaptive_harmony/core/reasoning.py +12 -0
- adaptive_harmony/core/reward_client/__init__.py +19 -0
- adaptive_harmony/core/reward_client/client.py +160 -0
- adaptive_harmony/core/reward_client/reward_types.py +49 -0
- adaptive_harmony/core/reward_client/websocket_utils.py +18 -0
- adaptive_harmony/core/rich_counter.py +351 -0
- adaptive_harmony/core/rl_utils.py +38 -0
- adaptive_harmony/core/schedulers.py +38 -0
- adaptive_harmony/core/structured_output.py +385 -0
- adaptive_harmony/core/utils.py +365 -0
- adaptive_harmony/environment/__init__.py +8 -0
- adaptive_harmony/environment/environment.py +121 -0
- adaptive_harmony/evaluation/__init__.py +1 -0
- adaptive_harmony/evaluation/evaluation_artifact.py +67 -0
- adaptive_harmony/graders/__init__.py +20 -0
- adaptive_harmony/graders/answer_relevancy_judge/__init__.py +3 -0
- adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py +102 -0
- adaptive_harmony/graders/answer_relevancy_judge/prompts.py +58 -0
- adaptive_harmony/graders/base_grader.py +265 -0
- adaptive_harmony/graders/binary_judge/__init__.py +8 -0
- adaptive_harmony/graders/binary_judge/binary_judge.py +202 -0
- adaptive_harmony/graders/binary_judge/prompts.py +125 -0
- adaptive_harmony/graders/combined_grader.py +118 -0
- adaptive_harmony/graders/context_relevancy_judge/__init__.py +3 -0
- adaptive_harmony/graders/context_relevancy_judge/context_relevancy_judge.py +128 -0
- adaptive_harmony/graders/context_relevancy_judge/prompts.py +84 -0
- adaptive_harmony/graders/exceptions.py +9 -0
- adaptive_harmony/graders/faithfulness_judge/__init__.py +3 -0
- adaptive_harmony/graders/faithfulness_judge/faithfulness_judge.py +159 -0
- adaptive_harmony/graders/faithfulness_judge/prompts.py +22 -0
- adaptive_harmony/graders/range_judge/__init__.py +7 -0
- adaptive_harmony/graders/range_judge/prompts.py +232 -0
- adaptive_harmony/graders/range_judge/range_judge.py +188 -0
- adaptive_harmony/graders/range_judge/types.py +12 -0
- adaptive_harmony/graders/reward_server_grader.py +36 -0
- adaptive_harmony/graders/templated_prompt_judge.py +237 -0
- adaptive_harmony/graders/utils.py +79 -0
- adaptive_harmony/logging_table.py +1 -0
- adaptive_harmony/metric_logger.py +452 -0
- adaptive_harmony/parameters/__init__.py +2 -0
- adaptive_harmony/py.typed +0 -0
- adaptive_harmony/runtime/__init__.py +2 -0
- adaptive_harmony/runtime/context.py +2 -0
- adaptive_harmony/runtime/data.py +2 -0
- adaptive_harmony/runtime/decorators.py +2 -0
- adaptive_harmony/runtime/model_artifact_save.py +2 -0
- adaptive_harmony/runtime/runner.py +27 -0
- adaptive_harmony/runtime/simple_notifier.py +2 -0
- adaptive_harmony-0.1.23.dist-info/METADATA +37 -0
- adaptive_harmony-0.1.23.dist-info/RECORD +67 -0
- adaptive_harmony-0.1.23.dist-info/WHEEL +5 -0
- adaptive_harmony-0.1.23.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
# ruff: noqa: F403, F401
|
|
2
|
+
from typing import TYPE_CHECKING
|
|
3
|
+
|
|
4
|
+
from harmony_client import (
|
|
5
|
+
EvalSample as EvalSample,
|
|
6
|
+
)
|
|
7
|
+
from harmony_client import (
|
|
8
|
+
EvalSampleInteraction as EvalSampleInteraction,
|
|
9
|
+
)
|
|
10
|
+
from harmony_client import (
|
|
11
|
+
Grade as Grade,
|
|
12
|
+
)
|
|
13
|
+
from harmony_client import (
|
|
14
|
+
HarmonyClient as HarmonyClient,
|
|
15
|
+
)
|
|
16
|
+
from harmony_client import (
|
|
17
|
+
HarmonyJobNotifier as HarmonyJobNotifier,
|
|
18
|
+
)
|
|
19
|
+
from harmony_client import (
|
|
20
|
+
InferenceModel as InferenceModel,
|
|
21
|
+
)
|
|
22
|
+
from harmony_client import (
|
|
23
|
+
JobArtifact as JobArtifact,
|
|
24
|
+
)
|
|
25
|
+
from harmony_client import (
|
|
26
|
+
JobNotifier as JobNotifier,
|
|
27
|
+
)
|
|
28
|
+
from harmony_client import (
|
|
29
|
+
ModelBuilder as ModelBuilder,
|
|
30
|
+
)
|
|
31
|
+
from harmony_client import (
|
|
32
|
+
StageNotifier as StageNotifier,
|
|
33
|
+
)
|
|
34
|
+
from harmony_client import (
|
|
35
|
+
StringThread as StringThread,
|
|
36
|
+
)
|
|
37
|
+
from harmony_client import (
|
|
38
|
+
TokenizedThread as TokenizedThread,
|
|
39
|
+
)
|
|
40
|
+
from harmony_client import (
|
|
41
|
+
TrainingModel as TrainingModel,
|
|
42
|
+
)
|
|
43
|
+
from harmony_client import (
|
|
44
|
+
get_client as get_client,
|
|
45
|
+
)
|
|
46
|
+
from harmony_client import parameters as parameters
|
|
47
|
+
from harmony_client import runtime as runtime
|
|
48
|
+
from rich.progress import Progress
|
|
49
|
+
|
|
50
|
+
if TYPE_CHECKING:
    # For static analysis, expose the real StringTurn type from harmony_client.
    from harmony_client import StringTurn as StringTurn
else:
    # At runtime, provide a lightweight NamedTuple with the same (role, content)
    # field layout so `StringTurn` is always importable from this package.
    from typing import NamedTuple

    class StringTurn(NamedTuple):
        # One conversation turn: the speaker role and its text content.
        role: str
        content: str
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
from harmony_client.artifacts.custom_artifact import CustomArtifact
|
|
61
|
+
from harmony_client.artifacts.dataset_artifact import DatasetArtifact
|
|
62
|
+
from harmony_client.file_storage import (
|
|
63
|
+
FileStorage,
|
|
64
|
+
FileStorageConfig,
|
|
65
|
+
LocalFileStorageConfig,
|
|
66
|
+
S3FileStorageConfig,
|
|
67
|
+
StoredFile,
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
import adaptive_harmony.core.rl_utils as rl_utils
|
|
71
|
+
from adaptive_harmony.core.dataset import DataSet
|
|
72
|
+
from adaptive_harmony.core.schedulers import CombinedSchedule, CosineScheduler, CosineSchedulerWithoutWarmup, Scheduler
|
|
73
|
+
from adaptive_harmony.evaluation.evaluation_artifact import EvaluationArtifact
|
|
74
|
+
from adaptive_harmony.metric_logger import Logger, WandbLogger
|
|
75
|
+
|
|
76
|
+
# Ensure key classes are available at module level
|
|
77
|
+
__all__ = [
|
|
78
|
+
"StringThread",
|
|
79
|
+
"StringTurn",
|
|
80
|
+
"TokenizedThread",
|
|
81
|
+
"InferenceModel",
|
|
82
|
+
"ModelBuilder",
|
|
83
|
+
"TrainingModel",
|
|
84
|
+
"HarmonyClient",
|
|
85
|
+
"get_client",
|
|
86
|
+
"DataSet",
|
|
87
|
+
"CosineScheduler",
|
|
88
|
+
"CombinedSchedule",
|
|
89
|
+
"CosineSchedulerWithoutWarmup",
|
|
90
|
+
"Scheduler",
|
|
91
|
+
"WandbLogger",
|
|
92
|
+
"Logger",
|
|
93
|
+
"FileStorage",
|
|
94
|
+
"FileStorageConfig",
|
|
95
|
+
"LocalFileStorageConfig",
|
|
96
|
+
"S3FileStorageConfig",
|
|
97
|
+
"StoredFile",
|
|
98
|
+
"EvaluationArtifact",
|
|
99
|
+
"CustomArtifact",
|
|
100
|
+
"DatasetArtifact",
|
|
101
|
+
"rl_utils",
|
|
102
|
+
"Grade",
|
|
103
|
+
"EvalSample",
|
|
104
|
+
"EvalSampleInteraction",
|
|
105
|
+
"JobArtifact",
|
|
106
|
+
]
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
# Patch StringThread to use rich for display
|
|
110
|
+
from harmony_client.runtime.model_artifact_save import save_with_artifact
|
|
111
|
+
|
|
112
|
+
from adaptive_harmony.core.display import _stringthread_repr, _tokenizedthread_repr
|
|
113
|
+
from adaptive_harmony.core.image_utils import string_thread_to_html_string
|
|
114
|
+
|
|
115
|
+
# Patch InferenceModel to have json output capabilities
|
|
116
|
+
from adaptive_harmony.core.structured_output import generate_and_validate, render_pydantic_model, render_schema
|
|
117
|
+
|
|
118
|
+
StringThread.__repr__ = _stringthread_repr # type: ignore
|
|
119
|
+
TokenizedThread.__repr__ = _tokenizedthread_repr # type: ignore
|
|
120
|
+
setattr(StringThread, "_repr_html_", string_thread_to_html_string)
|
|
121
|
+
setattr(InferenceModel, "generate_and_validate", generate_and_validate)
|
|
122
|
+
setattr(InferenceModel, "render_schema", staticmethod(render_schema))
|
|
123
|
+
setattr(InferenceModel, "render_pydantic_model", staticmethod(render_pydantic_model))
|
|
124
|
+
|
|
125
|
+
_original_training_model_save = TrainingModel.save
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
async def _save_with_artifact_wrapper(model: TrainingModel, model_name: str, inference_only: bool = True, ctx=None):
    """Replacement for `TrainingModel.save` that routes through `save_with_artifact`.

    Passes along the original, unpatched `TrainingModel.save` (captured above as
    `_original_training_model_save`) so the artifact wrapper can still perform
    the underlying model save. Installed via `setattr(TrainingModel, "save", ...)`.
    """
    return await save_with_artifact(model, model_name, inference_only, ctx, _original_training_model_save)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
setattr(TrainingModel, "save", _save_with_artifact_wrapper)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
async def spawn_train(self: ModelBuilder, name: str, max_batch_size: int) -> TrainingModel:
    """Spawn a training model while rendering a rich progress bar.

    Wraps `spawn_train_with_progress`: polls the returned future's progress and
    mirrors it into a `rich.progress` bar, then resolves the future into the
    spawned `TrainingModel`. Patched onto `ModelBuilder` at import time.

    Args:
        name: Model name/identifier to spawn.
        max_batch_size: Maximum training batch size for the spawned model.
    """
    fut = await self.spawn_train_with_progress(name, max_batch_size)  # type:ignore

    with Progress() as pbar:
        # Progress is reported in [0, 1]; use that as the total from the start
        # (the original `total=1000` was overwritten on the first update anyway).
        task = pbar.add_task("Loading model", total=1.0)

        # `< 1.0` instead of `!= 1.0`: also terminates if the reported progress
        # ever overshoots 1.0, avoiding a float-equality infinite loop.
        while (prog := await fut._await_progress()) < 1.0:
            pbar.update(task, completed=prog, total=1.0)
        pbar.update(task, completed=1.0, total=1.0)

    return await fut.get()
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
async def spawn_inference(self: ModelBuilder, name: str) -> InferenceModel:
    """Spawn an inference model while rendering a rich progress bar.

    Wraps `spawn_inference_with_progress`: polls the returned future's progress
    and mirrors it into a `rich.progress` bar, then resolves the future into the
    spawned `InferenceModel`. Patched onto `ModelBuilder` at import time.

    Args:
        name: Model name/identifier to spawn.
    """
    fut = await self.spawn_inference_with_progress(name)  # type:ignore

    with Progress() as pbar:
        # Progress is reported in [0, 1]; use that as the total from the start
        # (the original `total=1000` was overwritten on the first update anyway).
        task = pbar.add_task("Loading model", total=1.0)

        # `< 1.0` instead of `!= 1.0`: also terminates if the reported progress
        # ever overshoots 1.0, avoiding a float-equality infinite loop.
        while (prog := await fut._await_progress()) < 1.0:
            pbar.update(task, completed=prog, total=1.0)
        pbar.update(task, completed=1.0, total=1.0)

    return await fut.get()
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
setattr(ModelBuilder, "spawn_inference", spawn_inference)
|
|
162
|
+
setattr(ModelBuilder, "spawn_train", spawn_train)
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from .callbacks import (
|
|
2
|
+
CheckpointCallback as CheckpointCallback,
|
|
3
|
+
)
|
|
4
|
+
from .callbacks import (
|
|
5
|
+
EnvironmentValidationCallback as EnvironmentValidationCallback,
|
|
6
|
+
)
|
|
7
|
+
from .callbacks import (
|
|
8
|
+
GenerateSamplesCallback as GenerateSamplesCallback,
|
|
9
|
+
)
|
|
10
|
+
from .callbacks import (
|
|
11
|
+
GraderEvalCallback as GraderEvalCallback,
|
|
12
|
+
)
|
|
13
|
+
from .callbacks import (
|
|
14
|
+
RecipeCallback as RecipeCallback,
|
|
15
|
+
)
|
|
16
|
+
from .callbacks import (
|
|
17
|
+
ValidationLossCallback as ValidationLossCallback,
|
|
18
|
+
)
|
|
19
|
+
from .dpo import DPO as DPO
|
|
20
|
+
from .env_grpo import ENVGRPO
|
|
21
|
+
from .grpo import GRPO as GRPO
|
|
22
|
+
from .gspo import GSPO as GSPO
|
|
23
|
+
from .ppo import PPO as PPO
|
|
24
|
+
from .rm import RewardModelling as RewardModelling
|
|
25
|
+
from .sft import SFT as SFT
|
|
26
|
+
|
|
27
|
+
__all__ = [
|
|
28
|
+
"SFT",
|
|
29
|
+
"PPO",
|
|
30
|
+
"GRPO",
|
|
31
|
+
"ENVGRPO",
|
|
32
|
+
"DPO",
|
|
33
|
+
"RewardModelling",
|
|
34
|
+
"RecipeCallback",
|
|
35
|
+
"GenerateSamplesCallback",
|
|
36
|
+
"ValidationLossCallback",
|
|
37
|
+
"CheckpointCallback",
|
|
38
|
+
"GraderEvalCallback",
|
|
39
|
+
"EnvironmentValidationCallback",
|
|
40
|
+
]
|
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
from harmony_client import (
|
|
6
|
+
InferenceModel,
|
|
7
|
+
StringThread,
|
|
8
|
+
TrainingModel,
|
|
9
|
+
)
|
|
10
|
+
from loguru import logger
|
|
11
|
+
|
|
12
|
+
from adaptive_harmony.core.utils import async_map, async_map_fallible
|
|
13
|
+
from adaptive_harmony.environment import EnvironmentFactory
|
|
14
|
+
from adaptive_harmony.graders import BaseGrader
|
|
15
|
+
from adaptive_harmony.logging_table import Table
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class RecipeCallback:
|
|
19
|
+
def __init__(self, frequency: float, log_key_prefix: str | None = None):
|
|
20
|
+
self.frequency = frequency
|
|
21
|
+
self.last_call = -1.0
|
|
22
|
+
self.log_key_prefix = log_key_prefix
|
|
23
|
+
|
|
24
|
+
async def maybe_call(self, current_percentage: float) -> dict[str, Any]:
|
|
25
|
+
if current_percentage - self.last_call >= self.frequency:
|
|
26
|
+
self.last_call = current_percentage
|
|
27
|
+
callback_dict = await self.callback(current_percentage)
|
|
28
|
+
prefixed_dict = {
|
|
29
|
+
(f"{self.log_key_prefix}/{key}" if self.log_key_prefix else key): value
|
|
30
|
+
for key, value in callback_dict.items()
|
|
31
|
+
}
|
|
32
|
+
return prefixed_dict
|
|
33
|
+
return {}
|
|
34
|
+
|
|
35
|
+
@abstractmethod
|
|
36
|
+
async def callback(self, current_percentage: float) -> dict[str, Any]: ...
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class GenerateSamplesCallback(RecipeCallback):
    """Periodically generates completions for a fixed thread set and logs them.

    Logs a table (system prompt / prompt / response columns) under `log_key`,
    plus generation-length statistics, prefixed with "generation/".
    """

    def __init__(
        self,
        thread_set: list[StringThread],
        model: InferenceModel,
        frequency: float,
        log_key: str = "samples",
    ):
        super().__init__(frequency, log_key_prefix="generation")
        self.thread_set = thread_set
        self.model = model
        self.log_key = log_key

    async def callback(self, current_percentage: float) -> dict[str, Any]:
        logger.info("Entering generation callback...")
        generation_tokens = await async_map_fallible(self.model.generate_tokens, self.thread_set)
        generation_results = await async_map_fallible(self.model.detokenize_thread, generation_tokens)
        # Lengths come from the tokenized threads (token counts of the generated turn).
        gen_lengths = [sample.len_last_turn() for sample in generation_tokens]

        generation_logs = {
            self.log_key: Table()
            .add_column(
                "system",
                [
                    # Guard against empty threads before indexing turn 0 —
                    # consistent with the guard already used for the prompt column.
                    sample.get_turns()[0].content
                    if (sample.get_turns() and sample.get_turns()[0].role == "system")
                    else ""
                    for sample in generation_results
                ],
            )
            .add_column(
                "prompt",
                [
                    # Prompt = everything between the optional system turn and the
                    # generated final turn.
                    repr(
                        StringThread(
                            sample.get_turns()[1:-1]
                            if (sample.get_turns() and sample.get_turns()[0].role == "system")
                            else sample.get_turns()[:-1]
                        )
                    )
                    for sample in generation_results
                ],
            )
            .add_column("response", [response.last_content() for response in generation_results]),
            "generation_length_mean": np.mean(gen_lengths).item(),
            "generation_length_std": np.std(gen_lengths).item(),
            "num_samples": len(generation_results),
        }
        return generation_logs
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class ValidationLossCallback(RecipeCallback):
    """Periodically computes mean per-token negative log-likelihood on a validation set.

    Logs `{log_key: mean loss}` under the "validation/" prefix.
    """

    def __init__(
        self,
        validation_set: list[StringThread],
        model: InferenceModel,
        frequency: float = 0.1,
        log_key: str = "loss",
    ):
        super().__init__(frequency, log_key_prefix="validation")
        self.validation_set = validation_set
        self.model = model
        self.log_key = log_key

    async def callback(self, current_percentage: float) -> dict[str, float]:
        logger.info("Entering validation loss callback...")
        tokens = await async_map_fallible(self.model.tokenize_thread, self.validation_set)
        logprobs = await async_map(self.model.logprobs_per_token, tokens)
        # Per-sample loss: mean negative logprob over that sample's tokens.
        # (Removed a dead `losses = []` that was immediately overwritten.)
        losses = [-(sum(lp) / len(lp)) for lp in logprobs]

        if not losses:
            # Empty validation set (or every tokenization failed): report nothing
            # rather than raising ZeroDivisionError on the mean below.
            logger.warning("No validation losses computed; skipping validation loss log.")
            return {}

        return {self.log_key: sum(losses) / len(losses)}
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class CheckpointCallback(RecipeCallback):
    """Periodically saves the training model under a progress-suffixed name."""

    def __init__(
        self,
        model: TrainingModel,
        checkpoint_name: str,
        frequency: float = 0.2,
    ):
        super().__init__(frequency, log_key_prefix="checkpointing")
        self.last_call = 0.0  # avoid saving the model at the first period
        self.model = model
        self.model_log_name = checkpoint_name

    async def callback(self, current_percentage: float):
        # Saved name embeds progress rounded to 3 decimals, e.g. "<name>-0.25".
        logger.info(f"Saving checkpoint at {current_percentage * 100} % of training ...")
        await self.model.save(f"{self.model_log_name}-{round(current_percentage, 3)}")
        # Nothing to log beyond the save itself.
        return {}
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class GraderEvalCallback(RecipeCallback):
    """Periodically samples completions for a validation set and grades them.

    Generates with the model at `temperature`, grades each completion with the
    provided `grader`, and logs the grader's accumulated metrics (prefixed
    "rewards/") plus generation-length statistics, all under `log_key/`.
    """

    def __init__(
        self,
        validation_set: list[StringThread],
        model: InferenceModel,
        grader: BaseGrader,
        frequency: float,
        log_key: str = "validation",
        clear_grader_logs: bool = True,
        temperature: float = 0.0,
    ):
        super().__init__(frequency, log_key_prefix=log_key)
        self.validation_set = validation_set
        self.model = model
        self.grader = grader
        self.clear_grader_logs = clear_grader_logs
        self.temperature = temperature

    async def callback(self, current_percentage: float) -> dict[str, float | Table]:
        logger.info("Entering grader evaluation callback...")
        sampler = self.model.temperature(self.temperature)

        token_threads = await async_map_fallible(sampler.generate_tokens, self.validation_set)
        text_threads = await async_map(sampler.detokenize_thread, token_threads)
        grades = await async_map_fallible(self.grader.grade, text_threads)
        # Token counts of the generated (last) turn, per sample.
        lengths = [thread.len_last_turn() for thread in token_threads]

        grader_logs = self.grader.get_logs(clear=self.clear_grader_logs)
        logs: dict[str, float | Table] = {f"rewards/{key}": value for key, value in grader_logs.items()}
        logs["generation_length_mean"] = float(np.mean(lengths).item())
        logs["generation_length_std"] = float(np.std(lengths).item())
        logs["num_samples"] = float(len(grades))
        return logs
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class EnvironmentValidationCallback(RecipeCallback):
    """Periodically rolls out and grades full environment trajectories on a validation set.

    For each validation thread, builds an environment from the factory, runs a
    trajectory with the model at `temperature`, and logs score/turn-count
    statistics (plus environment-internal logs under "env/" and, optionally, a
    sample table) under the `log_key/` prefix.
    """

    def __init__(
        self,
        validation_set: list[StringThread],
        model: InferenceModel,
        env_factory: EnvironmentFactory,
        frequency: float,
        log_key: str = "validation",
        clear_env_logs: bool = True,
        temperature: float = 0.0,
        num_samples_log: int = 0,
    ):
        super().__init__(frequency, log_key_prefix=log_key)
        self.validation_set = validation_set
        self.model = model
        self.env_factory = env_factory
        self.clear_env_logs = clear_env_logs
        self.temperature = temperature
        # How many trajectories to include verbatim in the logged sample table (0 = none).
        self.num_samples_log = num_samples_log

    async def generate_trajectory(self, initial_thread: StringThread) -> tuple[StringThread, float, int]:
        """Roll out one trajectory; return (trajectory, cumulative score, assistant-turn count)."""
        # A fresh environment per sample, seeded from the thread's metadata.
        env = self.env_factory.create_environment(initial_thread.metadata)
        temp_model = self.model.temperature(self.temperature)
        trajectory, trajectory_score = await env.generate_trajectory_and_grade(temp_model, initial_thread)
        num_turns = len([turn for turn in trajectory.get_turns() if turn.role == "assistant"])
        return trajectory, trajectory_score.cumulative_score, num_turns

    async def callback(self, current_percentage: float) -> dict[str, float | Table]:
        logger.info("Entering environment validation callback...")

        # Fallible map: samples whose rollout fails are dropped from the stats.
        results = await async_map_fallible(self.generate_trajectory, self.validation_set)

        trajectories = [traj for traj, _, _ in results]
        scores = [score for _, score, _ in results]
        num_turns_list = [num_turns for _, _, num_turns in results]

        validation_logs = {
            "score_mean": np.mean(scores).item(),
            "score_std": np.std(scores).item(),
            "num_turns_mean": np.mean(num_turns_list).item(),
            "num_turns_std": np.std(num_turns_list).item(),
            "num_samples": len(results),
        }

        # Pull (and optionally clear) metrics accumulated inside the environments.
        env_logs = self.env_factory.get_logs(clear=self.clear_env_logs)
        validation_logs.update({f"env/{key}": value for key, value in env_logs.items()})

        if self.num_samples_log > 0:
            samples = [repr(traj) for traj in trajectories[: self.num_samples_log]]
            samples_scores = scores[: self.num_samples_log]
            table = Table().add_column("trajectory", samples).add_column("score", samples_scores)
            validation_logs["samples"] = table

        logger.info(f"Validation Mean score: {validation_logs['score_mean']:.4f}")
        return validation_logs
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Awaitable, Callable, Sequence
|
|
5
|
+
|
|
6
|
+
import anyio
|
|
7
|
+
import numpy as np
|
|
8
|
+
from loguru import logger as loguru
|
|
9
|
+
|
|
10
|
+
from adaptive_harmony import DataSet, StringThread
|
|
11
|
+
from adaptive_harmony.common.callbacks import RecipeCallback
|
|
12
|
+
from adaptive_harmony.core.utils import hash_dataset
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class CheckpointManager:
    """Save/restore recipe training state as JSON checkpoints on disk.

    A checkpoint records recipe identity (name + dataset/hyperparameter hashes),
    dataset cursor and RNG state, per-callback throttling state, completion
    percentage, and any recipe-specific payload supplied by the caller.
    Checkpointing is disabled when no `job_id` is given or the
    HARMONY_NO_CHECKPOINTING environment variable is set.
    """

    def __init__(
        self,
        recipe_name: str,
        dataset: DataSet,
        threads_dataset: Sequence[StringThread],
        callbacks: Sequence[RecipeCallback],
        hyperparams_hash: str,
        job_id: str | None = None,
        checkpoint_frequency: float = 0.2,
        restart_from_checkpoint: str | None = None,
    ):
        self.recipe_name = recipe_name
        self.dataset = dataset
        # Hash of the thread data, used to refuse restoring onto a different dataset.
        self.dataset_hash = hash_dataset(threads_dataset)
        self.hyperparams_hash = hyperparams_hash
        self.callbacks = callbacks
        self.checkpoint_frequency = checkpoint_frequency
        self.last_checkpoint_percentage = 0.0
        self.restart_from_checkpoint = restart_from_checkpoint
        self.job_id = job_id
        # None when checkpointing is disabled (see _init_folder).
        self.checkpointing_folder = self._init_folder()

    def _init_folder(self) -> str | None:
        """Return the per-job checkpoint folder, or None if checkpointing is disabled."""
        if self.job_id is None or os.getenv("HARMONY_NO_CHECKPOINTING") is not None:
            loguru.warning("Checkpointing is disabled for this recipe.")
            return None
        return os.path.join(os.getenv("RECIPE_CHECKPOINTS_DIR", "/checkpoints"), self.job_id)

    async def maybe_restore_checkpoint(
        self,
        recipe_specific_checkpoint_loading: Callable[[dict], Awaitable[None]],
    ) -> None:
        """Restore state from `restart_from_checkpoint` if one was configured.

        Validates recipe name and dataset/hyperparameter hashes before mutating
        any state, then restores the dataset cursor, RNG state, callback
        throttles, and finally hands the full checkpoint dict to the
        recipe-specific loader.
        """
        if self.restart_from_checkpoint is None:
            return

        checkpoint_path = Path(self.restart_from_checkpoint)
        # Accepts either a checkpoint file or a directory (latest file is picked).
        checkpoint_file = self._resolve_checkpoint_file(checkpoint_path)

        assert checkpoint_file, f"Checkpoint file not found: {checkpoint_path}."

        loguru.info(f"Loading {self.recipe_name} checkpoint from: {checkpoint_file}")

        contents = ""
        async with await anyio.open_file(checkpoint_file, "r") as f:
            contents = await f.read()
        checkpoint_data = json.loads(contents)

        # Refuse to load a checkpoint produced by a different recipe type.
        assert checkpoint_data.get("recipe_type") == self.recipe_name, (
            f"Recipe type mismatch: checkpoint is '{checkpoint_data.get('recipe_type')}', "
            f"but trying to load into {self.recipe_name}"
        )

        assert checkpoint_data.get("dataset_hash") == self.dataset_hash, (
            "Dataset hash mismatch between checkpoint and current dataset."
        )

        assert checkpoint_data.get("hyperparams_hash") == self.hyperparams_hash, (
            "Hyperparameters hash mismatch between checkpoint and current recipe configuration."
        )

        # Restore the dataset cursor and shuffle/RNG state.
        self.dataset.idx = checkpoint_data.get("dataset_idx", 0)

        access_indices_list = checkpoint_data.get("dataset_access_indices", [])
        if access_indices_list:
            self.dataset.access_indices = np.array(access_indices_list)

        rng_state = checkpoint_data.get("dataset_rng_state")
        if rng_state:
            self.dataset.rng.bit_generator.state = rng_state

        # Restore per-callback throttling so callbacks don't re-fire immediately.
        callback_states = checkpoint_data.get("callback_last_calls", [])
        assert len(callback_states) == len(self.callbacks), "Mismatch in number of callbacks when loading checkpoint"
        for i, callback in enumerate(self.callbacks):
            callback.last_call = callback_states[i]

        await recipe_specific_checkpoint_loading(checkpoint_data)

        self.last_checkpoint_percentage = checkpoint_data.get("completion_percentage", 0.0)

        loguru.info(f"Checkpoint restored: starting {self.recipe_name} from {self.last_checkpoint_percentage:.2%}.")

    async def maybe_checkpoint(
        self,
        completion_percentage: float,
        recipe_specific_checkpoint_saving: Callable[[], Awaitable[dict]],
    ) -> bool:
        """Save a checkpoint if due; return True iff the caller should exit training.

        Returns True only for a graceful-exit request (GRACEFUL_EXIT file present
        in the checkpoint folder). Periodic saves return False so training continues.
        """
        if self.checkpointing_folder is None:
            return False

        # No point checkpointing a finished run.
        if completion_percentage >= 1.0:
            return False

        if await self._check_graceful_exit_file():
            loguru.info(f"Graceful exit requested. Saving checkpoint and exiting {self.recipe_name} training loop.")
            await self._save_checkpoint(completion_percentage, recipe_specific_checkpoint_saving)
            return True

        if completion_percentage - self.last_checkpoint_percentage >= self.checkpoint_frequency:
            await self._save_checkpoint(completion_percentage, recipe_specific_checkpoint_saving)
            self.last_checkpoint_percentage = completion_percentage

        return False

    async def _save_checkpoint(
        self,
        completion_percentage: float,
        get_save_config: Callable[[], Awaitable[dict]],
    ) -> None:
        """Write a checkpoint-<pct>.json file combining manager and recipe state."""
        assert self.checkpointing_folder is not None  # will never be called outside of this condition
        progress_pct = int(completion_percentage * 100)
        checkpoint_dir = Path(self.checkpointing_folder)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        loguru.info(f"Checkpointing {self.recipe_name} at {checkpoint_dir} ({progress_pct}%)...")

        recipe_data = await get_save_config()

        # Recipe-specific keys are merged last and may therefore override the
        # generic keys above — NOTE(review): confirm no recipe reuses these names.
        checkpoint_data = {
            "recipe_type": self.recipe_name,
            "dataset_hash": self.dataset_hash,
            "hyperparams_hash": self.hyperparams_hash,
            "dataset_idx": self.dataset.idx,
            "dataset_access_indices": self.dataset.access_indices.tolist(),
            "dataset_rng_state": self.dataset.rng.bit_generator.state,
            "callback_last_calls": [callback.last_call for callback in self.callbacks],
            "completion_percentage": completion_percentage,
            **recipe_data,
        }

        checkpoint_file = checkpoint_dir / f"checkpoint-{progress_pct}.json"

        data_dump = json.dumps(checkpoint_data, indent=2)
        async with await anyio.open_file(checkpoint_file, "w") as f:
            await f.write(data_dump)

        loguru.info(f"Checkpoint saved: {checkpoint_file}")

    async def _check_graceful_exit_file(self) -> bool:
        """True if a GRACEFUL_EXIT marker file exists in the checkpoint folder."""
        if self.checkpointing_folder is None:
            return False
        return (Path(self.checkpointing_folder) / "GRACEFUL_EXIT").exists()

    @staticmethod
    def _resolve_checkpoint_file(path: Path) -> Path | None:
        """Resolve a checkpoint path: for a directory, pick the highest-numbered file."""
        if path.is_dir():
            # Sort numerically by the <pct> in "checkpoint-<pct>.json".
            files = sorted(path.glob("checkpoint-*.json"), key=lambda p: int(p.stem.split("-")[1]))
            return files[-1] if files else None
        return path if path.exists() else None
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from typing import Callable, Sequence
|
|
2
|
+
|
|
3
|
+
from tqdm.auto import tqdm
|
|
4
|
+
|
|
5
|
+
from adaptive_harmony import (
|
|
6
|
+
CosineScheduler,
|
|
7
|
+
DataSet,
|
|
8
|
+
JobNotifier,
|
|
9
|
+
Logger,
|
|
10
|
+
StageNotifier,
|
|
11
|
+
StringThread,
|
|
12
|
+
TrainingModel,
|
|
13
|
+
)
|
|
14
|
+
from adaptive_harmony.common.callbacks import RecipeCallback
|
|
15
|
+
from adaptive_harmony.core.utils import async_map_batch, log_args
|
|
16
|
+
from adaptive_harmony.metric_logger import StdoutLogger
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class DPO:
    """Direct Preference Optimization training recipe.

    Trains `model` on (positive, negative) thread pairs against a frozen
    reference model cloned from it at the start of `run()`.
    """

    @log_args
    def __init__(
        self,
        dataset: list[tuple[StringThread, StringThread]],  # (positive_sample, negative_sample)
        model: TrainingModel,
        logger: Logger = StdoutLogger(),
        # NOTE(review): these defaults (StdoutLogger(), JobNotifier()...) are
        # instantiated once at import time and shared across DPO instances;
        # callbacks=[] is likewise a shared mutable default (not mutated here).
        stage_notifier: StageNotifier = JobNotifier().stage_notifier("DPO Training"),
        callbacks: Sequence[RecipeCallback] = [],
        lr: float = 1e-6,
        lr_scheduler: Callable[[float], float] | None = None,
        samples_per_batch=32,
        max_grad_norm=1.0,
        kl_beta=0.1,
        epochs=1,
        skip_nan_gradients: bool = False,
    ):
        # Core components
        # Frozen reference model; set by run() before any process_sample call.
        self.model_ref = None
        self.dataset = DataSet(dataset)
        self.model = model
        self.logger = logger
        self.stage_notifier = stage_notifier
        self.callbacks = callbacks
        # Explicit scheduler wins; otherwise a cosine schedule peaking at `lr`.
        self.lr_schedule = lr_scheduler or CosineScheduler(lr)
        self.samples_per_batch = samples_per_batch
        self.max_grad_norm = max_grad_norm
        self.skip_nan_gradients = skip_nan_gradients

        # DPO HP's
        self.kl_beta = kl_beta
        self.epochs = epochs

    @property
    def training_completion_percentage(self):
        # Fraction of total training done, spreading dataset passes over `epochs`.
        return self.dataset.completion_percentage() / self.epochs

    async def process_sample(self, sample: tuple[StringThread, StringThread]):
        """Accumulate the DPO loss gradient for one (positive, negative) pair."""
        assert self.model_ref is not None, "Calling `process_sample_dpo` before reference model has been set"

        pos, neg = sample
        # Reference logprobs come from the frozen clone, not the trained model.
        ref_logprobs_pos = await self.model_ref.logprobs(pos)
        ref_logprobs_neg = await self.model_ref.logprobs(neg)
        await self.model.train_dpo(pos, neg, ref_logprobs_pos, ref_logprobs_neg, self.kl_beta)

    async def run(self):
        """Run the full DPO training loop: batch, optimizer step, log, repeat."""
        # Freeze a reference copy of the current weights for the KL term.
        self.model_ref = await self.model.clone_inf()

        self.stage_notifier.report_progress(
            tot_num_samples=len(self.dataset) * self.epochs,
            processed_num_samples=self.dataset.idx,
            monitoring_link=self.logger.training_monitoring_link,
        )

        with tqdm(total=100) as pbar:
            while self.training_completion_percentage < 1.0:
                # Callbacks throttle themselves via maybe_call; log only non-empty results.
                for callback in self.callbacks:
                    if logs := await callback.maybe_call(self.training_completion_percentage):
                        self.logger(logs)

                # Accumulate gradients over one batch, then take one optimizer step.
                await async_map_batch(self.process_sample, self.dataset, self.samples_per_batch)
                cp = self.training_completion_percentage
                current_lr = self.lr_schedule(cp)
                pbar.update(cp * 100.0 - pbar.n)
                logs = await self.model.optim_step(
                    current_lr, wd=0, max_grad_norm=self.max_grad_norm, skip_nan_gradients=self.skip_nan_gradients
                )
                self.logger(logs | dict(completion_percentage=cp))

                self.stage_notifier.report_progress(
                    tot_num_samples=len(self.dataset) * self.epochs,
                    processed_num_samples=self.dataset.idx,
                    monitoring_link=self.logger.training_monitoring_link,
                )
|